{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "mOtlkKOJk11B"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "2CTPLx3Ukwrv"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wSIW5dYmdGfE"
      },
      "source": [
        "# Transformer Position Encoding\n",
        "\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/examples/blob/master/community/en/position_encoding.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003e\n",
        "    Run in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/examples/blob/master/community/en/position_encoding.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003e\n",
        "    View source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ZPoPd7GbdK0G"
      },
      "source": [
        "See [this notebook](https://github.com/site/en/r2/tutorials/sequences/transformer.ipynb) for a walk-through of a full transformer implementation.\n",
        "\n",
        "The transformer architecture uses stacked attention layers in place of CNNs or RNNs. This makes it easy to learn long-range dependencies, but the architecture contains no built-in information about the relative positions of items in a sequence.\n",
        "\n",
        "To give the model access to this information, the transformer architecture adds a position encoding to the input.\n",
        "\n",
        "This encoding is a vector of sines and cosines at each position, where each sine-cosine pair rotates at a different frequency.\n",
        "\n",
        "Nearby locations will have similar position-encoding vectors."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "XLC5-V6sefM6"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "y6BDOVCnKPSW"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import matplotlib.pyplot as plt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "oCajWknYcYpV"
      },
      "source": [
        "The angle rates range from `1 [rads/step]` to `min_rate [rads/step]` over the vector depth.\n",
        "\n",
        "Formula for angle rate:\n",
        "\n",
        "$$angle\\_rate_d = (min\\_rate)^{d / d_{max}} $$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ADsz4y85ANsB"
      },
      "outputs": [],
      "source": [
        "num_positions = 50\n",
        "depth = 512\n",
        "min_rate = 1/10000\n",
        "\n",
        "assert depth%2 == 0, \"Depth must be even.\"\n",
        "angle_rate_exponents = np.linspace(0,1,depth//2)\n",
        "angle_rates = min_rate**(angle_rate_exponents)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "E51Ne4cNe1i8"
      },
      "source": [
        "The resulting exponent goes from `0` to `1`, causing the `angle_rates` to drop exponentially from `1` to `min_rate`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ohldgW_Xg2sC"
      },
      "outputs": [],
      "source": [
        "plt.semilogy(angle_rates)\n",
        "plt.xlabel('Depth')\n",
        "plt.ylabel('Angle rate [rads/step]')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "fMaA6pUqh-tQ"
      },
      "source": [
        "Broadcasting a multiply over positions and angle rates gives a map of the position-encoding angles as a function of position and depth."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Rs2ABZ3XhYFg"
      },
      "outputs": [],
      "source": [
        "positions = np.arange(num_positions) \n",
        "angle_rads = (positions[:, np.newaxis])*angle_rates[np.newaxis, :]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "eHEHWQa8_51Z"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize = (14,8))\n",
        "plt.pcolormesh(\n",
        "    # Convert to degrees, and wrap around at 360\n",
        "    np.degrees(angle_rads) % 360,\n",
        "    # Use a cyclical colormap so that color(0) == color(360)\n",
        "    cmap='hsv', vmin=0, vmax=360)\n",
        "\n",
        "plt.xlim([0,len(angle_rates)])\n",
        "plt.ylabel('Position')\n",
        "plt.xlabel('Depth')\n",
        "bar = plt.colorbar(label='Angle [deg]')\n",
        "bar.set_ticks(np.linspace(0,360,6+1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "MEGyhu_JiNDS"
      },
      "source": [
        "Raw angles are not a good model input: they are either unbounded or, if wrapped, discontinuous. So take the sine and cosine:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "BBnQuSSPGl8t"
      },
      "outputs": [],
      "source": [
        "sines = np.sin(angle_rads)\n",
        "cosines = np.cos(angle_rads)\n",
        "pos_encoding = np.concatenate([sines, cosines], axis=-1)"
      ]
    },
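    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Written out, the cell above builds the following vector for each position $p$. This is only a restatement of the code, with $r_i$ standing for `angle_rates[i]` and $n$ for `depth/2`; both symbols are introduced here for illustration:\n",
        "\n",
        "$$PE(p) = \\left[ \\sin(p r_0), \\dots, \\sin(p r_{n-1}), \\cos(p r_0), \\dots, \\cos(p r_{n-1}) \\right]$$\n",
        "\n",
        "Note that this layout places all the sines first and all the cosines second, whereas the paper interleaves the sines and cosines along the depth axis."
      ]
    },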
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "UQrE2_nPHjXC"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(14,8))\n",
        "plt.pcolormesh(pos_encoding, \n",
        "               # Use a diverging colormap so it's clear where zero is.\n",
        "               cmap='RdBu', vmin=-1, vmax=1)\n",
        "plt.xlim([0,depth])\n",
        "plt.ylabel('Position')\n",
        "plt.xlabel('Depth')\n",
        "plt.colorbar()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "e536lUb-k-UQ"
      },
      "source": [
        "## Nearby positions"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "UzWWHKK1lu0z"
      },
      "source": [
        "Nearby locations will have similar position-encoding vectors.\n",
        "\n",
        "To demonstrate, compare one position's encoding (here position 20) with each of the others:"
      ]
    },
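    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "kruiT670zjrZ"
      },
      "outputs": [],
      "source": [
        "pos_encoding_at_20 = pos_encoding[20]\n",
        "\n",
        "# Dot product of every position's encoding with the encoding at position 20.\n",
        "dots = np.dot(pos_encoding, pos_encoding_at_20)\n",
        "# Sum of squared errors between every position's encoding and the one at position 20.\n",
        "SSE = np.sum((pos_encoding - pos_encoding_at_20)**2, axis=1)"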
      ]
    },
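    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "A short derivation suggests why the two measures computed above carry the same information. This is a sketch, with $r_i$ standing for `angle_rates[i]` and $n$ for `depth/2`; both symbols are introduced here for illustration. The dot product of two encodings depends only on the offset between the positions:\n",
        "\n",
        "$$PE(p) \\cdot PE(q) = \\sum_{i=0}^{n-1} \\left[ \\sin(p r_i) \\sin(q r_i) + \\cos(p r_i) \\cos(q r_i) \\right] = \\sum_{i=0}^{n-1} \\cos((p - q) r_i)$$\n",
        "\n",
        "Each encoding has squared norm $n$, so the sum of squared errors is a linear function of the dot product:\n",
        "\n",
        "$$\\lVert PE(p) - PE(q) \\rVert^2 = 2n - 2 \\, PE(p) \\cdot PE(q)$$\n",
        "\n",
        "With `depth = 512` this gives `SSE = 512 - 2*dots`, so the dot product peaks exactly where the SSE reaches zero."
      ]
    },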
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xH-sU6mzmjXI"
      },
      "source": [
        "Regardless of how you compare the vectors, they are most similar at position 20, and clearly diverge as you move away:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "YQ-MqQ5Eo0ZT"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(10,8))\n",
        "plt.subplot(2,1,1)\n",
        "plt.plot(dots)\n",
        "plt.ylabel('Dot product')\n",
        "plt.subplot(2,1,2)\n",
        "plt.plot(SSE)\n",
        "plt.ylabel('SSE')\n",
        "plt.xlabel('Position')\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "BVgAzBJ7iYda"
      },
      "source": [
        "## Relative positions\n",
        "\n",
        "The [paper](https://arxiv.org/pdf/1706.03762.pdf) explains, at the end of section 3.5, that the position encoding at any fixed offset can be written as a linear function of the current position encoding.\n",
        "\n",
        "To demonstrate, this section builds a matrix that calculates these relative position encodings. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "aHOBuNYnLHFF"
      },
      "outputs": [],
      "source": [
        "def transition_matrix(position_delta, angle_rates = angle_rates):\n",
        "  # Implement as a matrix multiply:\n",
        "  #    sin(a+b) = sin(a)*cos(b)+cos(a)*sin(b)\n",
        "  #    cos(a+b) = cos(a)*cos(b)-sin(a)*sin(b)\n",
        "  \n",
        "  # b\n",
        "  angle_delta = position_delta*angle_rates\n",
        "\n",
        "  # sin(b), cos(b)\n",
        "  sin_delta = np.sin(angle_delta)\n",
        "  cos_delta = np.cos(angle_delta)\n",
        "\n",
        "  I = np.eye(len(angle_rates))\n",
        "  \n",
        "  # sin(a+b) = sin(a)*cos(b)+cos(a)*sin(b)\n",
        "  update_sin = np.concatenate([I*cos_delta, I*sin_delta], axis=0)\n",
        "  \n",
        "  # cos(a+b) = cos(a)*cos(b)-sin(a)*sin(b)\n",
        "  update_cos = np.concatenate([-I*sin_delta, I*cos_delta], axis=0)\n",
        "\n",
        "  return np.concatenate([update_sin, update_cos], axis=-1)"
      ]
    },
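    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "For a single angle rate, the update in `transition_matrix` can be written as a 2-by-2 rotation. This is a sketch: $a$ is the current angle and $b$ is the angle change for the given position offset at that rate (the value stored in `angle_delta`):\n",
        "\n",
        "$$\\begin{bmatrix} \\sin(a+b) & \\cos(a+b) \\end{bmatrix} = \\begin{bmatrix} \\sin a & \\cos a \\end{bmatrix} \\begin{bmatrix} \\cos b & -\\sin b \\\\ \\sin b & \\cos b \\end{bmatrix}$$\n",
        "\n",
        "`transition_matrix` assembles one such block per angle rate, arranged to match the sines-then-cosines column layout of `pos_encoding`. The matrix depends on the offset but not on the position it is applied to."
      ]
    },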
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GZKHXFAxkMb_"
      },
      "source": [
        "For example, create the matrix that calculates the position encoding 10 steps back, from the current position encoding:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "x31yfQRQM6Q6"
      },
      "outputs": [],
      "source": [
        "position_delta = -10\n",
        "update = transition_matrix(position_delta)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "NznTFtWmklob"
      },
      "source": [
        "Applying this matrix to each position-encoding vector gives the position-encoding vector from 10 steps back, resulting in a shifted position-encoding map:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "hE7t9frqNQV3"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(14,8))\n",
        "plt.pcolormesh(np.dot(pos_encoding,update), cmap='RdBu', vmin=-1, vmax=1)\n",
        "plt.xlim([0,depth])\n",
        "plt.ylabel('Position')\n",
        "plt.xlabel('Depth')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "a1IHJT2Jk16a"
      },
      "source": [
        "Where the shifted and original maps overlap, they match to within floating-point precision:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "tRmt2d-HQD8j"
      },
      "outputs": [],
      "source": [
        "errors = np.dot(pos_encoding,update)[10:] - pos_encoding[:-10]\n",
        "abs(errors).max()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "position_encoding.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true,
      "version": "0.3.2"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
